#!/usr/bin/env python3

# Copied from timm/layers/patch_embed.py

from typing import Callable

import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Callable | None = None,
        bias: bool = True,
    ):
        super().__init__()

        self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
        self.num_patches = self._init_img_size(img_size)

        self.patch_size = (
            (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        )

        # layers
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def _init_img_size(self, img_size: int | tuple[int, int]):
        assert self.patch_size

        grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.size()

        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = self.norm(x)

        return x
